{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "# Pytorch-Struct\n",
    "\n",
    "[![Build Status](https://travis-ci.org/harvardnlp/pytorch-struct.svg?branch=master)](https://travis-ci.org/harvardnlp/pytorch-struct)\n",
    "[![Coverage Status](https://coveralls.io/repos/github/harvardnlp/pytorch-struct/badge.svg?branch=master)](https://coveralls.io/github/harvardnlp/pytorch-struct?branch=master)\n",
    "\n",
    "<p align=\"center\">\n",
    "  <img src=\"https://github.com/harvardnlp/pytorch-struct/raw/master/download.png\">\n",
    "  </p>\n",
    "\n",
    "\n",
    "\n",
    "A library of tested, GPU implementations of core structured prediction algorithms for deep learning applications.\n",
    "(or an implementation of <a href=\"https://www.cs.jhu.edu/~jason/papers/eisner.spnlp16.pdf\">Inside-Outside and Forward-Backward Algorithms Are Just Backprop\"<a/>)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Getting Started"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install -qU git+https://github.com/harvardnlp/pytorch-struct\n",
    "!pip install -q matplotlib"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch_struct import DependencyCRF, LinearChainCRF\n",
    "import matplotlib.pyplot as plt\n",
    "def show(x): plt.imshow(x.detach())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPUAAAD4CAYAAAA0L6C7AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAJu0lEQVR4nO3dzYtdBx2H8efrTF9MKlobQUyizUKUIEjrUKsFF63gKxbBRYUKusnGai2Ctt34D4joQoRQddNiF7ELkfqGL4ib4DQt1iQWQqttasW0opaKTdL+XMwIMWlyz9zc45n58XygkPuS0y9hnpx7z9whqSok9fGqqQdIWiyjlpoxaqkZo5aaMWqpmeUxDrrj9Uv15t2LP/SLI12pH+OoL5MRjgqvGmUtHH/0ilGOq3H8mxc4WS++4hfZKFG/efcyv/7RGxd+3CdOv7TwYwK8NEKAz7986cKPCbA9p0Y57pf3vHuU42ocB+vn533Ml99SM0YtNWPUUjNGLTVj1FIzRi01MyjqJB9M8liSY0nuHHuUpPnNjDrJEvBN4EPAXuCTSfaOPUzSfIacqa8DjlXV41V1ErgfuHncWZLmNSTqncBTZ9w+vn7f/0iyL8lqktVnnxvnk1+SZlvYhbKq2l9VK1W1suOqpUUdVtIGDYn6aWD3Gbd3rd8naRMaEvVvgbcm2ZPkUuAW4AfjzpI0r5k/pVVVp5PcBvwEWAK+U1WHR18maS6DfvSyqh4EHhx5i6QF8BNlUjNGLTVj1FIzRi01Y9RSM0YtNWPUUjNGLTVj1FIzRi01Y9RSM0YtNWPUUjNGLTVj1FIzRi01Y9RSM0YtNWPUUjNGLTVj1FIzRi01Y9RSM0YtNWPUUjNGLTVj1FIzRi01Y9RSM0YtNWPUUjNGLTVj1FIzRi01Y9RSMzOjTrI7yS+THElyOMnt/49hkuazPOA5p4EvVtWhJK8BHkrys6o6MvI2SXOYeaauqmeq6tD6r58HjgI7xx4maT4bek+d5GrgGuDgKzy2L8lqktVnn3tpMeskbdjgqJNcAXwf+EJV/fPsx6tqf1WtVNXKjquWFrlR0gYMijrJJawFfV9VPTDuJEkXY8jV7wDfBo5W1dfGnyTpYgw5U98AfAq4Mckj6/99eORdkuY081taVfUbIP+HLZIWwE+USc0YtdSMUUvNGLXUjFFLzRi11IxRS80YtdSMUUvNGLXUjFFLzRi11IxRS80YtdSMUUvNGLXUjFFLzRi11IxRS80YtdSMUUvNGLXUjFFLzRi11IxRS80YtdSMUUvNGLXUjFFLzRi11IxRS80YtdSMUUvNGLXUzOCokywleTjJD8ccJOnibORMfTtwdKwhkhZjUNRJdgEfAe4Zd46kizX0TP114EvAy+d7QpJ9SVaTrD773EsLGSdp42ZGneSjwF+r6qELPa+q9lfVSlWt7LhqaWEDJW3MkDP1DcDHkvwRuB+4Mcm9o66SNLeZUVfVXVW1q6quBm4BflFVt46+TNJc/D611MzyRp5cVb8CfjXKEkkL4ZlaasaopWaMWmrGqKVmjFpqZkNXv4c69rvtfHzXdWMcWiO549g4P6uzctnfRjnuv6oWfsx7//6uhR8T4O4djy38mNd94F/nfcwztdSMUUvNGLXUjFFLzRi11IxRS80YtdSMUUvNGLXUjFFLzRi11IxRS80YtdSMUUvNGLXUjFFLzRi11IxRS80YtdSMUUvNGLXUjFFLzRi11IxRS80YtdSMUUvNGLXUjFFLzRi11MygqJO8LsmBJH9IcjTJe8YeJmk+Q/8p228AP66qTyS5FNg24iZJF2Fm1EleC7wP+DRAVZ0ETo47S9K8hrz83gOcAL6b5OEk9yTZfvaTkuxLsppk9RQvLnyopGGGRL0MXAt8q6quAV4A7jz7SVW1v6pWqmrlEi5b8ExJQw2J+jhwvKoOrt8+wFrkkjahmVFX1V+Ap5K8bf2um4Ajo66SNLehV78/B9y3fuX7ceAz402SdDEGRV1VjwArI2+RtAB+okxqxqilZoxaasaopWaMWmrGqKVmjFpqxqilZoxaasaopWaMWmrGqKVmjFpqxqilZoxaasaopWaMWmrGqKVmjFpqxqilZoxaasaopWaMWmrGqKVmjFpqxqilZoxaasaopWaMWmrGqKVmjFpqxqilZoxaasaopWYGRZ3kjiSHk/w+yfeSXD72MEnzmRl1kp3A54GVqnoHsATcMvYwSfMZ+vJ7GXh1kmVgG/Dn8SZJuhgzo66qp4GvAk8CzwD/qKqfnv28JPuSrCZZPcWLi18qaZAhL7+vBG4G9gBvArYnufXs51XV/qpaqaqVS7hs8UslDTLk5ff7gSeq6kRVnQIeAN477ixJ8xoS9ZPA9Um2JQlwE3B03FmS5jXkPfVB4ABwCHh0/ffsH3mXpDktD3lSVX0F+MrIWyQtgJ8ok5oxaqkZo5aaMWqpGaOWmjFqqRmjlpoxaqkZo5aaMWqpGaOWmjFqqRmjlpoxaqkZo5aaMWqpGaOWmjFqqRmjlpoxaqkZo5aaMWqpGaOWmjFqqRmjlpoxaqkZo5aaSVUt/qDJCeBPA566A3h24QPGs5X2bqWtsLX2boatb6mqN7zSA6NEPVSS1apamWzABm2lvVtpK2ytvZt9qy+/pWaMWmpm6qi32j9ev5X2bqWtsLX2buqtk76nlrR4U5+pJS2YUUvNTBZ1kg8meSzJsSR3TrVjliS7k/wyyZEkh5PcPvWmIZIsJXk4yQ+n3nIhSV6X5ECSPyQ5muQ9U2+6kCR3rH8d/D7J95JcPvWms00SdZIl4JvAh4C9wCeT7J1iywCngS9W1V7geuCzm3jrmW4Hjk49YoBvAD+uqrcD72QTb06yE/g8sFJV7wCWgFumXXWuqc7U1wHHqurxqjoJ3A/cPNGWC6qqZ6rq0Pqvn2fti27ntKsuLMku4CPAPVNvuZAkrwXeB3wboKpOVtXfp1010zLw6iTLwDbgzxPvOcdUUe8Enjrj9nE2eSgASa4GrgEOTrtkpq8DXwJennrIDHuAE8B3198q3JNk+9Sjzqeqnga+CjwJPAP8o6p+Ou2qc3mhbKAkVwDfB75QVf+ces/5JPko8NeqemjqLQMsA9cC36qqa4AXgM18feVK1l5R7gHeBGxPcuu0q841VdRPA7vPuL1r/b5NKcklrAV9X1U9MPWeGW4APpbkj6y9rbkxyb3TTjqv48DxqvrvK58DrEW+Wb0feKKqTlTVKeAB4L0TbzrHVFH/Fnhrkj1JLmXtYsMPJtpyQUnC2nu+o1X1tan3zFJVd1XVrqq6mrU/119U1aY7mwBU1V+Ap5K8bf2um4AjE06a5Ung+iTb1r8ubmITXthbnuJ/WlWnk9wG/IS1K4jfqarDU2wZ4AbgU8CjSR5Zv+/uqnpwwk2dfA64b/0v98eBz0y857yq6mCSA8Ah1r4r8jCb8COjfkxUasYLZVIzRi01Y9RSM0YtNWPUUjNGLTVj1FIz/wHW6B6zE6OeNgAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Make some data.\n",
    "vals = torch.zeros(2, 10, 10) + 1e-5\n",
    "vals[:, :5, :5] = torch.rand(5)\n",
    "vals[:, 5:, 5:] = torch.rand(5) \n",
    "dist = DependencyCRF(vals.log())\n",
    "show(dist.log_potentials[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPUAAAD4CAYAAAA0L6C7AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAALDklEQVR4nO3dz49V9RnH8c+HmeF3q4gslCFlFkZDjC1mYlWMCzFWq5GFXUAiSd2wsCoaE6Pd+A+o0aTGhKAulOACSDTGCK0/FppIHX74A0YTAhYGsU5tVKCRGeDpYqYJBYZ75s75emaevF+JCXPv9fEJ8ubce+fMuY4IAchjWtMLAKgXUQPJEDWQDFEDyRA1kExniaGXXtIRixd11T7382Pza58pSZ3/rv/vtmmnynxXodT3Knz0P4Umo4SfdFxDccLnu69I1IsXdenvWxfVPveqD1bXPlOSFmycXfvMGd8P1z5Tkk53nPf/44R1/W1HkbkoY3u8M+Z9PP0GkiFqIBmiBpIhaiAZogaSIWogmUpR277d9pe299l+vPRSANrXMmrbHZKel3SHpCWSVtleUnoxAO2pcqS+TtK+iNgfEUOSXpO0ouxaANpVJeqFkg6d8fXA6G3/x/Ya2322+wa/O1XXfgDGqbY3yiJiXUT0RkTvgvkddY0FME5Voj4s6cwTubtHbwMwCVWJ+mNJV9jusT1d0kpJb5RdC0C7Wv6UVkSctP2ApK2SOiS9FBF7im8GoC2VfvQyIt6S9FbhXQDUgDPKgGSIGkiGqIFkiBpIhqiBZIpceHDPPxfomqfvr33utJm1j5QkDdx2svaZPZvLXCBw+nc/FZnLJ6rlwZEaSIaogWSIGkiGqIFkiBpIhqiBZIgaSIaogWSIGkiGqIFkiBpIhqiBZIgaSIaogWSIGkiGqIFkiBpIhqiBZIgaSIaogWSIGkimyNVEu46d1mUfHqt97oEVc2qfKUnzPqn/87QHlpf5jO6e14eLzEUeHKmBZIgaSIaogWSIGkiGqIFkiBpIhqiBZFpGbXuR7fds77W9x/ban2MxAO2pcvLJSUmPRsRO27+QtMP2XyNib+HdALSh5ZE6Io5ExM7RXx+V1C9pYenFALRnXK+pbS+WtFTS9vPct8Z2n+2+4eHj9WwHYNwqR217rqTNkh6OiB/Pvj8i1kVEb0T0dnWVOUcbQGuVorbdpZGgN0TElrIrAZiIKu9+W9KLkvoj4pnyKwGYiCpH6mWSVku6xfbu0X9+X3gvAG1q+S2tiPhAkn+GXQDUgDPKgGSIGkiGqIFkiBpIpsiFBxUhnzhV+9jud4ZqnylJB+6p/yKBswbK/NYeWTa3yNzLPioyFg3gSA0kQ9RAMkQNJEPUQDJEDSRD1EAyRA0kQ9RAMkQNJEPUQDJEDSRD1EAyRA0kQ9RAMkQNJEPUQDJEDSRD1EAyRA0kQ9RAMkQNJFPkkpfRMU1D82fWPnfaqah9piR1b6v/U4UGVx2tfaYkfXrTK0Xm/u7p3xSZi58fR2ogGaIGkiFqIBmiBpIhaiAZogaSIWogmcpR2+6wvcv2myUXAjAx4zlSr5XUX2oRAPWoFLXtbkl3Slpfdh0AE1X1SP2spMcknR7rAbbX2O6z3Tc8fLyW5QCMX8uobd8l6duI2HGhx0XEuojojYjerq45tS0IYHyqHKmXSbrb9leSXpN0i+1Xi24FoG0to46IJyKiOyIWS1op6d2IuLf4ZgDawvepgWTG9fPUEfG+pPeLbAKgFhypgWSIGkiGqIFkiBpIhqiBZBxR/xU6f+lL4rdeXvtclLP1691F5l71weoicxdsnF37zBnfD9c+U5JOd9R/tdodH/1FR38cOO9gjtRAMkQNJEPUQDJEDSRD1EAyRA0kQ9RAMkQNJEPUQDJEDSRD1EAyRA0kQ9RAMkQNJEPUQDJEDSRD1EAyRA0kQ9RAMkQNJEPUQDLj+iwt5HXN0/cXmTttZpGxGrjtZO0zezbXf9VPSZr+3U+1z/Sp02Pex5EaSIaogWSIGkiGqIFkiBpIhqiBZIgaSKZS1LYvtr3J9he2+23fUHoxAO2pevLJc5Lejog/2J4uqf7PEQVQi5ZR275I0s2S/ihJETEkaajsWgDaVeXpd4+kQUkv295le73tOWc/yPYa2322+4Z1ovZFAVRTJepOSddKeiEilko6Lunxsx8UEesiojciers0o+Y1AVRVJeoBSQMRsX30600aiRzAJNQy6oj4RtIh21eO3rRc0t6iWwFoW9V3vx+UtGH0ne/9ku4rtxKAiagUdUTsltRbeBcANeCMMiAZogaSIWogGaIGkiFqIBmuJgpJ0mUfHisy98CKc84orsW8TzpqnzmwvP6ZktTz+nD9Qz32lU85UgPJEDWQDFEDyRA1kAxRA8kQNZAMUQPJEDWQDFEDyRA1kAxRA8kQNZAMUQPJEDWQDFEDyRA1kAxRA8kQNZAMUQPJEDWQDBcehCTJJ04Vmdv9zlCRuQfuqf8igbMGyuRwZNnc2mcO7x/7eMyRGkiGqIFkiBpIhqiBZIgaSIaogWSIGkimUtS2H7G9x/bntjfanll6MQDtaRm17YWSHpLUGxFXS+qQtLL0YgDaU/Xpd6ekWbY7Jc2W9HW5lQBMRMuoI+KwpKckHZR0RNIPEbHt7MfZXmO7z3bfsE7UvymASqo8/Z4naYWkHkmXS5pj+96zHxcR6yKiNyJ6uzSj/k0BVFLl6fetkg5ExGBEDEvaIunGsmsBaFeVqA9Kut72bNuWtFxSf9m1ALSrymvq7ZI2Sdop6bPRf2dd4b0AtKnSD5BGxJOSniy8C4AacEYZkAxRA8kQNZAMUQPJEDWQDFcThSRpaH6ZH7ybdiqKzO3e5tpnDq46WvtMSfr0pldqn3ndtsEx7+NIDSRD1EAyRA0kQ9RAMkQNJEPUQDJEDSRD1EAyRA0kQ9RAMkQNJEPUQDJEDSRD1EAyRA0kQ9RAMkQNJEPUQDJEDSRD1EAyRA0k44j6r/Zoe1DSPyo89FJJ/6p9gXKm0r5TaVdpau07GXb9VUQsON8dRaKuynZfRPQ2tsA4TaV9p9Ku0tTad7LvytNvIBmiBpJpOuqp9uH1U2nfqbSrNLX2ndS7NvqaGkD9mj5SA6gZUQPJNBa17dttf2l7n+3Hm9qjFduLbL9ne6/tPbbXNr1TFbY7bO+y/WbTu1yI7Yttb7L9he1+2zc0vdOF2H5k9M/B57Y32i7zcaET0EjUtjskPS/pDklLJK2yvaSJXSo4KenRiFgi6XpJf5rEu55praT+ppeo4DlJb0fEVZJ+rUm8s+2Fkh6S1BsRV0vqkLSy2a3O1dSR+jpJ+yJif0QMSXpN0oqGdrmgiDgSETtHf31UI3/oFja71YXZ7pZ0p6T1Te9yIbYvknSzpBclKSKGIuL7ZrdqqVPSLNudkmZL+rrhfc7RVNQLJR064+sBTfJQJMn2YklLJW1vdpOWnpX0mKTTTS/SQo+kQUkvj75UWG97TtNLjSUiDkt6StJBSUck/RAR25rd6ly8UVaR7bmSNkt6OCJ+bHqfsdi+S9K3EbGj6V0q6JR0raQXImKppOOSJvP7K/M08oyyR9LlkubYvrfZrc7VVNSHJS064+vu0dsmJdtdGgl6Q0RsaXqfFpZJutv2Vxp5WXOL7VebXWlMA5IGIuJ/z3w2aSTyyepWSQciYjAihiVtkXRjwzudo6moP5Z0he0e29M18mbDGw3tckG2rZHXfP0R8UzT+7QSEU9ERHdELNbI7+u7ETHpjiaSFBHfSDpk+8rRm5ZL2tvgSq0clHS97dmjfy6WaxK+sdfZxH80Ik7afkDSVo28g/hSROxpYpcKlklaLekz27tHb/tzRLzV4E6ZPChpw+hf7vsl3dfwPmOKiO22N0naqZHviuzSJDxllNNEgWR4owxIhqiBZIgaSIaogWSIGkiGqIFkiBpI5r/xtmivwuuQjwAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Compute marginals\n",
    "show(dist.marginals[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPUAAAD4CAYAAAA0L6C7AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAJu0lEQVR4nO3dz69cBR2G8ef1toAFI/7a0Da2C6JpTBRzQ1ASFkAiCIGNC0gw0U03omhMDLrxHzBGF8akQd1AYFFYEEMEI7pwU7kUgrZVQlChgLGa+CMYadGvi3tNaks7507neO5883ySJr0zw/CG3Iczczr3NFWFpD7eNvUASYtl1FIzRi01Y9RSM0YtNbNtjCd977tXas/u7Qt/3uef27Hw55SW0T95nZP1Rt7qvlGi3rN7O794fPfCn/cTV3xk4c8pLaND9ZNz3ufLb6kZo5aaMWqpGaOWmjFqqRmjlpoZFHWSm5L8JskLSe4de5Sk+c2MOskK8B3gZmAfcGeSfWMPkzSfIUfqq4EXqurFqjoJPATcPu4sSfMaEvVO4OXTvj6+cdv/SLI/yVqStRN//tei9knapIWdKKuqA1W1WlWr73vPyqKeVtImDYn6FeD0D3Lv2rhN0hY0JOqngCuT7E1yEXAH8Oi4syTNa+ZPaVXVm0nuBh4HVoDvV9WR0ZdJmsugH72sqseAx0beImkB/ESZ1IxRS80YtdSMUUvNGLXUzCgXHnz+uR2jXCTw8VefXfhzghc0VC8eqaVmjFpqxqilZoxaasaopWaMWmrGqKVmjFpqxqilZoxaasaopWaMWmrGqKVmjFpqxqilZoxaasaopWaMWmrGqKVmjFpqxqilZka5muhYxrrq5xhXKfUKpZqKR2qpGaOWmjFqqRmjlpoxaqkZo5aaMWqpmZlRJ9md5KdJjiY5kuSe/8cwSfMZ8uGTN4EvV9XhJO8Ank7y46o6OvI2SXOYeaSuqteq6vDG7/8OHAN2jj1M0nw29THRJHuAq4BDb3HffmA/wCXsWMA0SfMYfKIsyWXAw8AXq+pvZ95fVQeqarWqVrdz8SI3StqEQVEn2c560A9U1SPjTpJ0IYac/Q7wPeBYVX1z/EmSLsSQI/W1wKeB65M8u/HrkyPvkjSnmSfKqurnQP4PWyQtgJ8ok5oxaqkZo5aaMWqpmaW68OBYxrhI4BgXMwQvaKjZPFJLzRi11IxRS80YtdSMUUvNGLXUjFFLzRi11IxRS80YtdSMUUvNGLXUjFFLzRi11IxRS80YtdSMUUvNGLXUjFFLzRi11IxRS814NdGRjHXVT69Sqlk8UkvNGLXUjFFLzRi11IxRS80YtdSMUUvNDI46yUqSZ5L8cMxBki7MZo7U9wDHxhoiaTEGRZ1kF3ALcN+4cyRdqKFH6m8BXwH+fa4HJNmfZC3J2ineWMg4SZs3M+oktwJ/rKqnz/e4qjpQVatVtbqdixc2UNLmDDlSXwvcluR3wEPA9UnuH3WVpLnNjLqqvlpVu6pqD3AH8GRV3TX6Mklz8c+ppWY29fPUVfUz4GejLJG0EB6ppWaMWmrGqKVmjFpqxqilZrya6JLxKqWaxSO11IxRS80YtdSMUUvNGLXUjFFLzRi11IxRS80YtdSMUUvNGLXUjFFLzRi11IxRS80YtdSMUUvNGLXUjFFLzRi11IxRS80YtdSMVxMV4FVKO/FILTVj1FIzRi01Y9RSM0YtNWPUUjNGLTUzKOoklyc5mOTXSY4l+djYwyTNZ+iHT74N/KiqPpXkImDHiJskXYCZUSd5J3Ad8BmAqjoJnBx3lqR5DXn5vRc4AfwgyTNJ7kty6ZkPSrI/yVqStVO8sfChkoYZEvU24KPAd6vqKuB14N4zH1RVB6pqtapWt3PxgmdKGmpI1MeB41V1aOPrg6xHLmkLmhl1Vf0BeDnJBzZuugE4OuoqSXMbevb788ADG2e+XwQ+O94kSRdiUNRV9SywOvIWSQvgJ8qkZoxaasaopWaMWmrGqKVmvJqoRrVsVyldFld/4h/nvM8jtdSMUUvNGLXUjFFLzRi11IxRS80YtdSMUUvNGLXUjFFLzRi11IxRS80YtdSMUUvNGLXUjFFLzRi11IxRS80YtdSMUUvNeOFBLaWxLmi4LJ6vP5/zPo/UUjNGLTVj1FIzRi01Y9RSM0YtNWPUUjODok7ypSRHkvwqyYNJLhl7mKT5zIw6yU7gC8BqVX0IWAHuGHuYpPkMffm9DXh7km3ADuDV8SZJuhAzo66qV4BvAC8BrwF/raonznxckv1J1pKsneKNxS+VNMiQl9/vAm4H9gJXAJcmuevMx1XVgapararV7Vy8+KWSBhny8vtG4LdVdaKqTgGPAB8fd5akeQ2J+iXgmiQ7kgS4ATg27ixJ8xrynvoQcBA4DPxy4585MPIuSXMa9PPUVfV14Osjb5G0AH6iTGrGqKVmjFpqxqilZoxaasaopWaMWmrGqKVmjFpqxqilZoxaasaopWaMWmrGqKVmjFpqxqilZoxaasaopWaMWmrGqKVmjFpqxqilZoxaasaopWaMWmrGqKVmjFpqJlW1+CdNTgC/H/DQ9wJ/WviA8SzT3mXaCsu1dytsfX9Vve+t7hgl6qGSrFXV6mQDNmmZ9i7TVliuvVt9qy+/pWaMWmpm6qiX7S+vX6a9y7QVlmvvlt466XtqSYs39ZFa0oIZtdTMZFEnuSnJb5K8kOTeqXbMkmR3kp8mOZrkSJJ7pt40RJKVJM8k+eHUW84nyeVJDib5dZJjST429abzSfKlje+DXyV5MMklU2860yRRJ1kBvgPcDOwD7kyyb4otA7wJfLmq9gHXAJ/bwltPdw9wbOoRA3wb+FFVfRD4MFt4c5KdwBeA1ar6ELAC3DHtqrNNdaS+Gnihql6sqpPAQ8DtE205r6p6raoOb/z+76x/0+2cdtX5JdkF3ALcN/WW80nyTuA64HsAVXWyqv4y7aqZtgFvT7IN2AG8OvGes0wV9U7g5dO+Ps4WDwUgyR7gKuDQtEtm+hbwFeDfUw+ZYS9wAvjBxluF+5JcOvWoc6mqV4BvAC8BrwF/raonpl11Nk+UDZTkMuBh4ItV9bep95xLkluBP1bV01NvGWAb8FHgu1V1FfA6sJXPr7yL9VeUe4ErgEuT3DXtqrNNFfUrwO7Tvt61cduWlGQ760E/UFWPTL1nhmuB25L8jvW3NdcnuX/aSed0HDheVf995XOQ9ci3qhuB31bViao6BTwCfHziTWeZKuqngCuT7E1yEesnGx6daMt5JQnr7/mOVdU3p94zS1V9tap2VdUe1v+7PllVW+5oAlBVfwBeTvKBjZtuAI5OOGmWl4BrkuzY+L64gS14Ym/bFP/Sqnozyd3A46yfQfx+VR2ZYssA1wKfBn6Z5NmN275WVY9NuKmTzwMPbPzP/UXgsxPvOaeqOpTkIHCY9T8VeYYt+JFRPyYqNeOJMqkZo5aaMWqpGaOWmjFqqRmjlpoxaqmZ/wB4dR5RAcfWuwAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Compute argmax\n",
    "show(dist.argmax.detach()[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compute scoring and enumeration (forward / inside)\n",
    "log_partition = dist.partition\n",
    "max_score = dist.log_prob(dist.argmax)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPUAAAD4CAYAAAA0L6C7AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAJ1klEQVR4nO3dzYtdBx2H8edrkiYmvqObJsVm4QuhqJWhrRZcNEJ9w25cVKigm2y0rSJIdeM/IKUupBCqbix2EbsQKUbxZeEmOE0DmqSFUrVNUzEitlIxSfHnYq4Q83bP3NzjmfnxfKCQmXt7+iXk6blzcudMqgpJfbxu6gGSlsuopWaMWmrGqKVmjFpqZusYB70u22sHu8Y4tCTgX7zKuTqbyz02StQ72MWt2T/GoSUBR+oXV3zMl99SM0YtNWPUUjNGLTVj1FIzRi01MyjqJB9L8kySZ5M8MPYoSYubG3WSLcB3gI8D+4DPJtk39jBJixlypr4FeLaqnquqc8BjwF3jzpK0qCFR7wZeuODjU7PP/Y8kB5KsJlk9z9ll7ZO0Tku7UFZVB6tqpapWtrF9WYeVtE5Don4RuOGCj/fMPidpAxoS9W+BdyXZm+Q64G7gx+POkrSoud+lVVWvJfkScBjYAnyvqo6PvkzSQgZ962VVPQE8MfIWSUvgO8qkZoxaasaopWaMWmrGqKVmRrnx4Lvf908OHz629OPeef0Hln5MgMOnN89WaR7P1FIzRi01Y9RSM0YtNWPUUjNGLTVj1FIzRi01Y9RSM0YtNWPUUjNGLTVj1FIzRi01Y9RSM0YtNWPUUjNGLTVj1FIzRi01Y9RSM6mqpR/0TXlb3Zr9Sz+upDVH6he8Un/L5R7zTC01Y9RSM0YtNWPUUjNGLTVj1FIzRi01MzfqJDck+VWSE0mOJ7n//zFM0mKG/Cjb14CvVtXRJG8Enkzy86o6MfI2SQuYe6auqpeq6ujs1/8ATgK7xx4maTHr+qHzSW4EbgaOXOaxA8ABgB3sXMI0SYsYfKEsyRuAHwFfrqpXLn68qg5W1UpVrWxj+zI3SlqHQVEn2cZa0I9W1ePjTpJ0LYZc/Q7wXeBkVT04/iRJ12LImfp24HPAHUmOzf75xMi7JC1o7oWyqvoNcNnv25S08fiOMqkZo5aaMWqpGaOWmlnXO8qmdvj0sVGOe+f1H1j6MTfTVvXimVpqxqilZoxaasaopWaMWmrGqKVmjFpqxqilZoxaasaopWaMWmrGqKVmjFpqxqilZoxaasaopWaMWmrGqKVmjFpqxqilZoxaamZT3U10M91Jc6yt3qVU83imlpoxaqkZo5aaMWqpGaOWmjFqqRmjlpoZHHWSLUmeSvKTMQdJujbrOVPfD5wca4ik5RgUdZI9wCeBR8adI+laDT1TPwR8Dfj3lZ6Q5ECS1SSr5zm7lHGS1m9u1Ek+Bfylqp682vOq6mBVrVTVyja2L22gpPUZcqa+Hfh0kj8CjwF3JPnBqKskLWxu1FX19araU1U3AncDv6yqe0ZfJmkh/j211My6vp+6qn4N/HqUJZKWwjO11IxRS80YtdSMUUvNGLXUzKa6m6i86+eYutyp1TO11IxRS80YtdSMUUvNGLXUjFFLzRi11IxRS80YtdSMUUvNGLXUjFFLzRi11IxRS80YtdSMUUvNGLXUjFFLzRi11IxRS80YtdSMdxOVZrrcqdUztdSMUUvNGLXUjFFLzRi11IxRS80YtdTMoKiTvCXJoSRPJzmZ5ENjD5O0mKFvPvk28NOq+kyS64CdI26SdA3mRp3kzcBHgM8DVNU54Ny4syQtasjL773AGeD7SZ5K8kiSXRc/KcmBJKtJVs9zdulDJQ0zJOqtwAeBh6vqZuBV4IGLn1RVB6tqpapWtrF9yTMlDTUk6lPAqao6Mvv4EGuRS9qA5kZdVX8GXkjyntmn9gMnRl0laWFDr37fCzw6u/L9HPCF8SZJuhaDoq6qY8DKyFskLYHvKJOaMWqpGaOWmjFqqRmjlprxbqIa1eHTx0Y5bpc7f47BM7XUjFFLzRi11IxRS80YtdSMUUvNGLXUjFFLzRi11IxRS80YtdSMUUvNGLXUjFFLzRi11IxRS80YtdSMUUvNGLXUjFFLzXjjQY1qrBsEjnFDwy43M/RMLTVj1FIzRi01Y9RSM0YtNWPUUjNGLTUzKOokX0lyPMnvk/wwyY6xh0lazNyok+wG7gNWquomYAtw99jDJC1m6MvvrcDrk2wFdgKnx5sk6VrMjbqqXgS+BTwPvAS8XFU/u/h5SQ4kWU2yep6zy18qaZAhL7/fCtwF7AWuB3Yluefi51XVwapaqaqVbWxf/lJJgwx5+f1R4A9VdaaqzgOPAx8ed5akRQ2J+nngtiQ7kwTYD5wcd5akRQ35mvoIcAg4Cvxu9u8cHHmXpAUN+n7qqvom8M2Rt0haAt9RJjVj1FIzRi01Y9RSM0YtNePdRAWMc3dOGO8OnZvpzp9j/N7ecuc/r/iYZ2qpGaOWmjFqqRmjlpoxaqkZo5aaMWqpGaOWmjFqqRmjlpoxaqkZo5aaMWqpGaOWmjFqqRmjlpoxaqkZo5aaMWqpGaOWmjFqqZlU1fIPmpwB/jTgqW8H/rr0AePZTHs301bYXHs3wtZ3VtU7LvfAKFEPlWS1qlYmG7BOm2nvZtoKm2vvRt/qy2+pGaOWmpk66s32w+s3097NtBU2194NvXXSr6klLd/UZ2pJS2bUUjOTRZ3kY0meSfJskgem2jFPkhuS/CrJiSTHk9w/9aYhkmxJ8lSSn0y95WqSvCXJoSRPJzmZ5ENTb7qaJF+Z/Tn4fZIfJtkx9aaLTRJ1ki3Ad4CPA/uAzybZN8WWAV4DvlpV+4DbgC9u4K0Xuh84OfWIAb4N/LSq3gu8nw28Oclu4D5gpapuArYAd0+76lJTnalvAZ6tqueq6hzwGHDXRFuuqqpeqqqjs1//g7U/dLunXXV1SfYAnwQemXrL1SR5M/AR4LsAVXWuqv4+7aq5tgKvT7IV2AmcnnjPJaaKejfwwgUfn2KDhwKQ5EbgZuDItEvmegj4GvDvqYfMsRc4A3x/9qXCI0l2TT3qSqrqReBbwPPAS8DLVfWzaVddygtlAyV5A/Aj4MtV9crUe64kyaeAv1TVk1NvGWAr8EHg4aq6GXgV2MjXV97K2ivKvcD1wK4k90y76lJTRf0icMMFH++ZfW5DSrKNtaAfrarHp94zx+3Ap5P8kbUva+5I8oNpJ13RKeBUVf33lc8h1iLfqD4K/KGqzlTVeeBx4MMTb7rEVFH/FnhXkr1JrmPtYsOPJ9pyVUnC2td8J6vqwan3zFNVX6+qPVV1I2u/r7+sqg13NgGoqj8DLyR5z+xT+4ETE06a53ngtiQ7Z38u9rMBL+xtneI/WlWvJfkScJi1K4jfq6rjU2wZ4Hbgc8Dvkhybfe4bVfXEhJs6uRd4dPY/9+eAL0y854qq6kiSQ8BR1v5W5Ck24FtGfZuo1IwXyqRmjFpqxqilZoxaasaopWaMWmrGqKVm/gO6kifnSx+yRwAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Compute samples \n",
    "show(dist.sample((1,)).detach()[0, 0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPUAAAD4CAYAAAA0L6C7AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAALy0lEQVR4nO3d7Wud9R3H8c8nJ6n2ZvN+QtPWFidKEYYSRC34wPpAp5vgxqigOAfrk3k7QXRP/AMmogwRijdsU3RQ3SbTuQnqYAyLsRW0jbJaXW+ss3XaWm+WJvnuQTLoWtNz9fT380q+e79AMDnx2y/xvHudXLlyxREhAHn0tb0AgLKIGkiGqIFkiBpIhqiBZPprDD3pxL5Ysrj86JF9pxSfKUkDu1x8Zt/+8eIzJUm1vlsxXmffqDS3Bvd16gyu8P/s84l9Go0vvvSJWyXqJYv79Zc/nlp87tDfflx8piQtXDOn+Mxjd+4rPlOSNLq/ztx/fVxl7MSevVXm1tC3YH6VuTE+UXzmy/uenvYxXn4DyRA1kAxRA8kQNZAMUQPJEDWQTKOobV9q+y3bm23fUXspAL3rGrXtjqT7JV0mabmkq20vr70YgN40OVKfJ2lzRGyJiFFJT0i6su5aAHrVJOpBSdsOeHv71Pv+h+3VtodtD3/4YfkraAA0U+xEWUSsiYihiBg66STOvwFtaVLfDkmLD3h70dT7AMxATaJ+RdIZtpfZniNplaTpryYH0KquP6UVEWO2b5D0J0kdSQ9HxMbqmwHoSaMfvYyIZyU9W3kXAAVwRgtIhqiBZIgaSIaogWSIGkimyo0H33z/VK34+U+Lz51b6Z577/yg/OAzfjm3+ExJ6t81VmUudxOVJj7/osrcGB0tP3Ni+kuxOVIDyRA1kAxRA8kQNZAMUQPJEDWQDFEDyRA1kAxRA8kQNZAMUQPJEDWQDFEDyRA1kAxRA8kQNZAMUQPJEDWQDFEDyRA1kAxRA8lUuZvonL1jGnx+d/G57151cvGZknTa78vPfPv7x5YfKumbv5n+LpJHo/PhR1XmVhNRfuZEhZlSnV0PgyM1kAxRA8kQNZAMUQPJEDWQDFEDyRA1kEzXqG0vtv2i7U22N9q++atYDEBvmlx8MibptohYb/trkl61/XxEbKq8G4AedD1SR8TOiFg/9e+fSBqRNFh7MQC9OaKvqW0vlXSOpHVf8thq28O2h0fHPyuzHYAj1jhq2wskPSnplojYe/DjEbEmIoYiYmhOZ17JHQEcgUZR2x7QZNCPRcRTdVcCcDSanP22pIckjUTEPfVXAnA0mhypV0i6VtLFtl+b+ufblfcC0KOu39KKiL9K8lewC4ACuKIMSIaogWSIGkiGqIFkqtx4UOMT8t5Pi49d8sxA8ZmStOUHXy8+87RnRovPlKS3vze/ytwzHyj/OZAk791XZW6Mj5cf2pfjfDBHaiAZogaSIWogGaIGkiFqIBmiBpIhaiAZogaSIWogGaIGkiFqIBmiBpIhaiAZogaSIWogGaIGkiFqIBmiBpIhaiAZogaSIWogmTp3E5WkiOIjOx/XuTPlst+W/7tt86o6d/1c+NeJKnO3XrWwytzFv67zu8rjs8+Lz3SnzjFuosadT8emf4gjNZAMUQPJEDWQDFEDyRA1kAxRA8kQNZBM46htd2xvsP2HmgsBODpHcqS+WdJIrUUAlNEoatuLJF0u6cG66wA4Wk2P1PdKul3StNco2l5te9j28OhE+Uv4ADTTNWrbV0j6ICJePdzHRcSaiBiKiKE5fXOLLQjgyDQ5Uq+Q9F3b70p6QtLFth+tuhWAnnWNOiLujIhFEbFU0ipJL0TENdU3A9ATvk8NJHNEP08dES9JeqnKJgCK4EgNJEPUQDJEDSRD1EAyRA0kU+VuojE2pvEPdhef64E6Nz/tfPhR8Zln/eKE4jMlaccVg1XmHv92hTteShq5a2mVuWc8+u/iMwd2flx8piR1FiwoPtPvD0z7GEdqIBmiBpIhaiAZogaSIWogGaIGkiFqIBmiBpIhaiAZogaSIWogGaIGkiFqIBmiBpIhaiAZogaSIWogGaIGkiFqIBmiBpIhaiCZOrfnjFCM7S8/drzOHS/dN1p+ZqVdB5+eqDJ3y3WLq8xd+nT554EkvfudecVnnv5E+eeBJPXt3lNl7rR/3lf6pwGojqiBZIgaSIaogWSIGkiGqIFkiBpIplHUto+3vdb2m7ZHbF9QezEAvWl68cl9kp6LiO/bniOp/Hf+ARTRNWrbx0m6SNIPJSkiRiXVufQGwFFr8vJ7maRdkh6xvcH2g7bnH/xBtlfbHrY9vF/lfyE4gGaaRN0v6VxJD0TEOZI+lXTHwR8UEWsiYigihgZ0TOE1ATTVJOrtkrZHxLqpt9dqMnIAM1DXqCPifUnbbJ859a6VkjZV3QpAz5qe/b5R0mNTZ763SLq+3koAjkajqCPiNUlDlXcBUABXlAHJEDWQDFEDyRA1kAxRA8nUuZuoJEVUmFnnDp0RLj90tM5dNGPP3ipzlzz3SZW5f7+uzs/+LHlmrPjMbZeeUHymJC35XYXnwu7pj8ccqYFkiBpIhqiBZIgaSIaogWSIGkiGqIFkiBpIhqiBZIgaSIaogWSIGkiGqIFkiBpIhqiBZIgaSIaogWSIGkiGqIFkiBpIpt6NB/s61Ub/X5uocENHSZ19dX6n+LIn51SZu+Xa8jNPfLnO53b75d8oPnP00enT5UgNJEPUQDJEDSRD1EAyRA0kQ9RAMkQNJNMoatu32t5o+w3bj9s+tvZiAHrTNWrbg5JukjQUEWdL6khaVXsxAL1p+vK7X9Jc2/2S5kl6r95KAI5G16gjYoekuyVtlbRT0p6I+PPBH2d7te1h28P7VeeSQwDdNXn5fYKkKyUtk7RQ0nzb1xz8cRGxJiKGImJoQMeU3xRAI01efl8i6Z2I2BUR+yU9JenCumsB6FWTqLdKOt/2PNuWtFLSSN21APSqydfU6yStlbRe0utT/82aynsB6FGjn6eOiLsk3VV5FwAFcEUZkAxRA8kQNZAMUQPJEDWQTJW7ibrTUee4rxefG6OjxWdKUoyNFZ/pTqW7qfa5ylh/Uedze8w/91WZe/qv5hef+c6Pvig+U5LeXvlI8ZnnPb9r2sc4UgPJEDWQDFEDyRA1kAxRA8kQNZAMUQPJEDWQDFEDyRA1kAxRA8kQNZAMUQPJEDWQDFEDyRA1kAxRA8kQNZAMUQPJEDWQDFEDyTgiyg+1d0n6R4MPPVnS7uIL1DOb9p1Nu0qza9+ZsOtpEXHKlz1QJeqmbA9HxFBrCxyh2bTvbNpVml37zvRdefkNJEPUQDJtRz3bfnn9bNp3Nu0qza59Z/SurX5NDaC8to/UAAojaiCZ1qK2fantt2xvtn1HW3t0Y3ux7Rdtb7K90fbNbe/UhO2O7Q22/9D2Lodj+3jba22/aXvE9gVt73Q4tm+deh68Yftx28e2vdPBWonadkfS/ZIuk7Rc0tW2l7exSwNjkm6LiOWSzpf0kxm864FuljTS9hIN3CfpuYg4S9K3NIN3tj0o6SZJQxFxtqSOpFXtbnWoto7U50naHBFbImJU0hOSrmxpl8OKiJ0RsX7q3z/R5JNusN2tDs/2IkmXS3qw7V0Ox/Zxki6S9JAkRcRoRHzc7lZd9Uuaa7tf0jxJ77W8zyHainpQ0rYD3t6uGR6KJNleKukcSeva3aSreyXdLmmi7UW6WCZpl6RHpr5UeNB2+d8mX0hE7JB0t6StknZK2hMRf253q0Nxoqwh2wskPSnplojY2/Y+07F9haQPIuLVtndpoF/SuZIeiIhzJH0qaSafXzlBk68ol0laKGm+7Wva3epQbUW9Q9LiA95eNPW+Gcn2gCaDfiwinmp7ny5WSPqu7Xc1+WXNxbYfbXelaW2XtD0i/vvKZ60mI5+pLpH0TkTsioj9kp6SdGHLOx2irahfkXSG7WW252jyZMPTLe1yWLatya/5RiLinrb36SYi7oyIRRGxVJOf1xciYsYdTSQpIt6XtM32mVPvWilpU4srdbNV0vm25009L1ZqBp7Y62/jD42IMds3SPqTJs8gPhwRG9vYpYEVkq6V9Lrt16be97OIeLbFnTK5UdJjU3+5b5F0fcv7TCsi1tleK2m9Jr8rskEz8JJRLhMFkuFEGZAMUQPJEDWQDFEDyRA1kAxRA8kQNZDMfwC/SrvqAKZqWAAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPUAAAD4CAYAAAA0L6C7AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAK7klEQVR4nO3d32/V9R3H8deLliLg/MHUOFs2MDEuxGTDNE4l8ULMFDWabLvARJN5w4VT0ZkYXUz8B5zRC2NCUG8keoFcEGfUxV/JvGisYKJQnYgIRVQMKj8USst7Fz1LGNCebw/fz77tO89HQkJPD2/e0T75nvM9p986IgQgj1lNLwCgXkQNJEPUQDJEDSRD1EAy3SWGLlgwK3r7umqf+++DF9Q+U5J69rn2mR4t9KpCoVcrPFbqVZAyc+PISJG5M8VhHdJIHDnlF26RqHv7urThH+fVPvf6d++qfaYk9T0/u/aZc749XPtMSfLYsSJzZx0otO/oWJG5o9t3FJk7UwzEGxN+joffQDJEDSRD1EAyRA0kQ9RAMkQNJFMpats32P7E9jbbD5VeCkDn2kZtu0vSU5JWSFoi6TbbS0ovBqAzVY7UV0jaFhHbI2JE0ouSbi27FoBOVYm6V9Ku4z4ebt32P2yvsj1oe3DfvjLvegLQXm0nyiJiTUT0R0T/ggWcfwOaUqW+3ZIWHvdxX+s2ANNQlajfk3SJ7cW2eyStlLSx7FoAOtX2u7QiYtT23ZJek9Ql6dmI2FJ8MwAdqfStlxHxiqRXCu8CoAac0QKSIWogGaIGkiFqIBmiBpIpcuHBz76+UH/8+19rnzu/0D9BO2+q/+J4F6+v/2KGkjR7309F5vqnI0XmxuEyczExjtRAMkQNJEPUQDJEDSRD1EAyRA0kQ9RAMkQNJEPUQDJEDSRD1EAyRA0kQ9RAMkQNJEPUQDJEDSRD1EAyRA0kQ9RAMkQNJEPUQDJFribafXBMF777Q+1zd954du0zJekX77j2mV+smFP7TElatPFYkbk93x8sMldRZl9MjCM1kAxRA8kQNZAMUQPJEDWQDFEDyRA1kEzbqG0vtP2W7a22t9he/f9YDEBnqrz5ZFTSAxGxyfbPJL1v+58RsbXwbgA60PZIHRF7ImJT6/cHJA1J6i29GIDOTOk5te1FkpZKGjjF51bZHrQ9eHT0x3q2AzBllaO2faaklyTdFxH7T/x8RKyJiP6I6J/dPa/OHQFMQaWobc/WeNDrImJD2ZUAnI4qZ78t6RlJQxHxePmVAJyOKkfqZZLukHSt7Q9av24svBeADrV9SSsi/iWp/m84BlAE7ygDkiFqIBmiBpIhaiCZIhce9LHQrINHap/b98ah2mdK0o5b6n+zzAXvlbng3s4Vc4vMvfiF+UXm+nD9XweYHEdqIBmiBpIhaiAZogaSIWogGaIGkiFqIBmiBpIhaiAZogaSIWogGaIGkiFqIBmiBpIhaiAZogaSIWogGaIGkiFqIBmiBpIhaiCZIlcTlSTNqv8n9XQVuEKpJP3y1a7aZ27/w+zaZ0rSWdvK/ASk4RULisxduHG0yFx9/U2ZuQlwpAaSIWogGaIGkiFqIBmiBpIhaiAZogaSqRy17S7bm22/XHIhAKdnKkfq1ZKGSi0CoB6VorbdJ+kmSWvLrgPgdFU9Uj8h6UFJE/4kddurbA/aHhwZ+7GW5QBMXduobd8s6ZuIeH+y+0XEmojoj4j+nq55tS0IYGqqHKmXSbrF9g5JL0q61vbzRbcC0LG2UUfEwxHRFxGLJK2U9GZE3F58MwAd4XVqIJkpfT91RLwt6e0imwCoBUdqIBmiBpIhaiAZogaSIWogmTJXEx0bk7/bX/tYu8yVNOd8d7D2mZeuPbP2mZI0fH2Zq34e/nkUmfvZHecXmbvokW1F5mbAkRpIhqiBZIgaSIaogWSIGkiGqIFkiBpIhqiBZIgaSIaogWSIGkiGqIFkiBpIhqiBZIgaSIaogWSIGkiGqIFkiBpIhqiBZIgaSKbM1URtaU5P/XNHx+qfKUmjo7WPnHXgp9pnStJF7xwoMvfTuwr8/5L06e+fKTL3+kd+W2RuBhypgWSIGkiGqIFkiBpIhqiBZIgaSIaogWQqRW37HNvrbX9se8j2VaUXA9CZqm8+eVLSqxHxJ9s9kuYV3AnAaWgbte2zJV0j6c+SFBEjkkbKrgWgU1Uefi+WtFfSc7Y3215re/6Jd7K9yvag7cGRsTJvkQTQXpWouyVdLunpiFgq6ZCkh068U0SsiYj+iOjv6Zpb85oAqqoS9bCk4YgYaH28XuORA5iG2kYdEV9J2mX70tZNyyVtLboVgI5VPft9j6R1rTPf2yXdWW4lAKejUtQR8YGk/sK7AKgB7ygDkiFqIBmiBpIhaiAZogaScUTUPvQsL4jfeXntcwGMG4g3tD/2+VSf40gNJEPUQDJEDSRD1EAyRA0kQ9RAMkQNJEPUQDJEDSRD1EAyRA0kQ9RAMkQNJEPUQDJEDSRD1EAyRA0kQ9RAMkQNJEPUQDJEDSRD1EAyRA0kQ9RAMkQNJEPUQDJEDSRD1EAylaK2fb/tLbY/sv2C7TNKLwagM22jtt0r6V5J/RFxmaQuSStLLwagM1UffndLmmu7W9I8SV+WWwnA6WgbdUTslvSYpJ2S9kj6ISJeP/F+tlfZHrQ9eFRH6t8UQCVVHn6fK+lWSYslXSRpvu3bT7xfRKyJiP6I6J+tOfVvCqCSKg+/r5P0eUTsjYijkjZIurrsWgA6VSXqnZKutD3PtiUtlzRUdi0AnarynHpA0npJmyR92PozawrvBaBD3VXuFBGPSnq08C4AasA7yoBkiBpIhqiBZIgaSIaogWSIGkiGqIFkiBpIhqiBZIgaSIaogWSIGkiGqIFkiBpIhqiBZIgaSIaogWSIGkiGqIFkiBpIhqiBZIgaSIaogWSIGkiGqIFkiBpIhqiBZBwR9Q+190r6osJdz5P0be0LlDOT9p1Ju0oza9/psOuvIuL8U32iSNRV2R6MiP7GFpiimbTvTNpVmln7TvddefgNJEPUQDJNRz3Tfnj9TNp3Ju0qzax9p/WujT6nBlC/po/UAGpG1EAyjUVt+wbbn9jeZvuhpvZox/ZC22/Z3mp7i+3VTe9Uhe0u25ttv9z0LpOxfY7t9bY/tj1k+6qmd5qM7ftbXwcf2X7B9hlN73SiRqK23SXpKUkrJC2RdJvtJU3sUsGopAciYomkKyX9ZRrverzVkoaaXqKCJyW9GhG/lvQbTeOdbfdKuldSf0RcJqlL0spmtzpZU0fqKyRti4jtETEi6UVJtza0y6QiYk9EbGr9/oDGv+h6m91qcrb7JN0kaW3Tu0zG9tmSrpH0jCRFxEhEfN/sVm11S5pru1vSPElfNrzPSZqKulfSruM+HtY0D0WSbC+StFTSQLObtPWEpAclHWt6kTYWS9or6bnWU4W1tuc3vdREImK3pMck7ZS0R9IPEfF6s1udjBNlFdk+U9JLku6LiP1N7zMR2zdL+iYi3m96lwq6JV0u6emIWCrpkKTpfH7lXI0/olws6SJJ823f3uxWJ2sq6t2SFh73cV/rtmnJ9myNB70uIjY0vU8byyTdYnuHxp/WXGv7+WZXmtCwpOGI+O8jn/Uaj3y6uk7S5xGxNyKOStog6eqGdzpJU1G/J+kS24tt92j8ZMPGhnaZlG1r/DnfUEQ83vQ+7UTEwxHRFxGLNP7f9c2ImHZHE0mKiK8k7bJ9aeum5ZK2NrhSOzslXWl7XuvrYrmm4Ym97ib+0ogYtX23pNc0fgbx2YjY0sQuFSyTdIekD21/0LrtbxHxSoM7ZXKPpHWtf9y3S7qz4X0mFBEDttdL2qTxV0U2axq+ZZS3iQLJcKIMSIaogWSIGkiGqIFkiBpIhqiBZIgaSOY/s4BvM195fOoAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Padding/Masking built into library.\n",
    "dist = DependencyCRF(vals, lengths=torch.tensor([10, 7]))\n",
    "show(dist.marginals[0])\n",
    "plt.show()\n",
    "show(dist.marginals[1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPUAAAD4CAYAAAA0L6C7AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAKoElEQVR4nO3dTYhd9R3G8eeZuZPoRImhtdK80Myi2qaCRC6ihioYF1pFNy4iKOgmm6pRBNFuXHYjVgsihKgbgy7GLESsWlAphTY4TgRNRtsQbV6MmNY0pkaTzOTXxdxCmmTmnrlz/p47P74fEDIv/n3UfD137tw5OiIEII+BpgcAqBdRA8kQNZAMUQPJEDWQTKvIoUuHY/ElS2s/N44UmSsV+AaAp+o/U5JisMy5Q4ePlzkYRXw7dVQnTn3rc32sSCWLL1mqX/z+ntrP/e6NH9V+piQNnKy/6sVHynyr8MSF5/z3OG+XbNtd5FyU8ZevRmf8GA+/gWSIGkiGqIFkiBpIhqiBZIgaSKZS1LZvsv2J7d22Hy09CkDvukZte1DSM5JulrRG0p2215QeBqA3Va7UV0naHRF7IuKEpJcl3V52FoBeVYl6haR9p729v/O+/2N7o+0x22OTR47VtQ/AHNX2RFlEbI6IdkS0W0uH6zoWwBxVifqApFWnvb2y8z4AfahK1O9J+qntEduLJG2Q9GrZWQB61fWntCJi0vZ9kt6UNCjp+YjYWXwZgJ5U+tHLiHhd0uuFtwCoAa8oA5IhaiAZogaSIWogGaIGkily48ELh47r+h/XfyO7P7TK3HjQp+q/mV+4zI0HTw2WufFgHD1a5FwUMjXz7Wq5UgPJEDWQDFEDyRA1kAxRA8kQNZAMUQPJEDWQDFEDyRA1kAxRA8kQNZAMUQPJEDWQDFEDyRA1kAxRA8kQNZAMUQPJEDWQDFEDyRS5m6gVGvLMdzsEUA5XaiAZogaSIWogGaIGkiFqIBmiBpIhaiCZrlHbXmX7Hdu7bO+0ven7GAagN1VefDIp6eGIGLd9oaT3bf8xInYV3gagB12v1BFxMCLGO78+KmlC0orSwwD0Zk5fU9teLWmtpO3n+NhG22O2x44dPl7POgBzVjlq2xdIekXSgxHx9Zkfj4jNEdGOiPbwssV1bgQwB5Witj2k6aC3RsS2spMAzEeVZ78t6TlJExHxZPlJAOajypV6naS7Jd1g+4POH78qvAtAj7p+Sysi/izJ38MWADXgFWVAMkQNJEPUQDJEDSRD1EAyRA0kQ9RAMkQNJEPUQDJEDSRD1EAyRA0kQ9RAMkQNJEPUQDJEDSRD1EAyRA0kQ9RAMkQNJEPUQDJEDSRD1EAyRA0kQ9RAMkQNJEPUQDJEDSRD1EAyRA0kQ9RAMkQNJEPUQDJEDSRTOWrbg7Z32H6t5CAA8zOXK/UmSROlhgCoR6Woba+UdIukLWXnAJivqlfqpyQ9IunUTJ9ge6PtMdtjxw4fr2UcgLnrGrXtWyV9GRHvz/Z5EbE5ItoR0R5etri2gQDmpsqVep2k22x/JullSTfYfrHoKgA96xp1RDwWESsjYrWkDZLejoi7ii8D0BO+Tw0k05rLJ0fEu5LeLbIEQC24UgPJEDWQDFEDyRA1kAxRA8nM6dnvqo79raXx9RfXfu6Kb8ZrP7OUmJrxFbXzM+Aix061f17k3L/fM1Tk3OFP6z939Qt7aj9Tkn7311dqP/OOW47O+DGu1EAyRA0kQ9RAMkQNJEPUQDJEDSRD1EAyRA0kQ9RAMkQNJEPUQDJEDSRD1EAyRA0kQ9RAMkQNJEPUQDJEDSRD1EAyRA0kQ9RAMkXuJhqTU5r611cljkYh0Spzl9KB8yeLnHtqUYG7lA6UucZdOrSk9jPP88x9caUGkiFqIBmiBpIhaiAZogaSIWogGaIGkqkUte2LbI/a/tj2hO1rSg8D0JuqLz55WtIbEXGH7UWShgtuAjAPXaO2vVTSdZLukaSIOCHpRNlZAHpV5eH3iKRDkl6wvcP2Fttnve7N9kbbY7bHTup47UMBVFMl6pakKyU9GxFrJX0j6dEzPykiNkdEOyLaQ1pc80wAVVWJer+k/RGxvfP2qKYjB9CHukYdEV9I2mf7ss671kvaVXQVgJ5Vffb7fklbO89875F0b7lJAOajUtQR8YGkduEtAGrAK8qAZIgaSIaogWSIGkiGqIFkitxNFAtPuMzdRO0ocm6UmZsCV2ogGaIGkiFqIBmiBpIhaiAZogaSIWogGaIGkiFqIBmiBpIhaiAZogaSIWogGaIGkiFqIBmiBpIhaiAZogaSIWogGaIGkil348GBwWJHo34xUOZOfgODk0XOjVaBGxoOlLnG/em7+s88OsvfPldqIBmiBpIhaiAZogaSIWogGaIGkiFqIJlKUdt+yPZO2x/Zfsn2eaWHAehN16htr5D0gKR2RFwuaVDShtLDAPSm6sPvlqTzbbckDUv6vNwkAPPRNeqIOCDpCUl7JR2UdCQi3jrz82xvtD1me+ykjte/FEAlVR5+L5N0u6QRScslLbF915mfFxGbI6IdEe0hLa5/KYBKqjz8vlHSpxFxKCJOStom6dqyswD0qkrUeyVdbXvYtiWtlzRRdhaAXlX5mnq7pFFJ45I+7Pw5mwvvAtCjSj9PHRGPS3q88BYANeAVZUAyRA0kQ9RAMkQNJEPUQDJF7ibqRUNqLV9e4mgUcnjVoiLnXj/ycZFz35m6tPYzT676Qe1nStJvr/hl7Wd+8Z9XZ/wYV2ogGaIGkiFqIBmiBpIhaiAZogaSIWogGaIGkiFqIBmiBpIhaiAZogaSIWogGaIGkiFqIBmiBpIhaiAZogaSIWogGaIGkiFqIBlHRP2H2ock/aPCp/5Q0j9rH1DOQtq7kLZKC2tvP2z9SURcfK4PFIm6KttjEdFubMAcLaS9C2mrtLD29vtWHn4DyRA1kEzTUS+0/3n9Qtq7kLZKC2tvX29t9GtqAPVr+koNoGZEDSTTWNS2b7L9ie3dth9takc3tlfZfsf2Lts7bW9qelMVtgdt77D9WtNbZmP7Itujtj+2PWH7mqY3zcb2Q53fBx/Zfsn2eU1vOlMjUdselPSMpJslrZF0p+01TWypYFLSwxGxRtLVkn7dx1tPt0nSRNMjKnha0hsR8TNJV6iPN9teIekBSe2IuFzSoKQNza46W1NX6qsk7Y6IPRFxQtLLkm5vaMusIuJgRIx3fn1U07/pVjS7ana2V0q6RdKWprfMxvZSSddJek6SIuJERPy72VVdtSSdb7slaVjS5w3vOUtTUa+QtO+0t/erz0ORJNurJa2VtL3ZJV09JekRSaeaHtLFiKRDkl7ofKmwxfaSpkfNJCIOSHpC0l5JByUdiYi3ml11Np4oq8j2BZJekfRgRHzd9J6Z2L5V0pcR8X7TWypoSbpS0rMRsVbSN5L6+fmVZZp+RDkiabmkJbbvanbV2ZqK+oCkVae9vbLzvr5ke0jTQW+NiG1N7+linaTbbH+m6S9rbrD9YrOTZrRf0v6I+N8jn1FNR96vbpT0aUQcioiTkrZJurbhTWdpKur3JP3U9ojtRZp+suHVhrbMyrY1/TXfREQ82fSebiLisYhYGRGrNf3P9e2I6LuriSRFxBeS9tm+rPOu9ZJ2NTipm72SrrY93Pl9sV59+MReq4m/aERM2r5P0puafgbx+YjY2cSWCtZJulvSh7Y/6LzvNxHxeoObMrlf0tbOf9z3SLq34T0ziojttkcljWv6uyI71IcvGeVlokAyPFEGJEPUQDJEDSRD1EAyRA0kQ9RAMkQNJPNfeeZOBpDWxAgAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Many other structured prediction approaches\n",
    "chain = torch.zeros(2, 10, 10, 10) + 1e-5\n",
    "chain[:, :, :, :] = vals.unsqueeze(-1).exp()\n",
    "chain[:, :, :, :] += torch.eye(10, 10).view(1, 1, 10, 10) \n",
    "chain[:, 0, :, 0] = 1\n",
    "chain[:, -1,9, :] = 1\n",
    "chain = chain.log()\n",
    "\n",
    "dist = LinearChainCRF(chain)\n",
    "show(dist.marginals.detach()[0].sum(-1))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Library\n",
    "\n",
    "Current distributions implemented:\n",
    "\n",
    "* LinearChainCRF \n",
    "* SemiMarkovCRF \n",
    "* DependencyCRF \n",
    "* TreeCRF \n",
    "\n",
    "\n",
    "Extensions:\n",
    "\n",
    "* Integration with `torchtext`, `pytorch-transformers`, `dgl`\n",
    "* Adapters for generative structured models (CFG / HMM / HSMM)\n",
    "* Common tree structured parameterizations TreeLSTM / SpanLSTM\n",
    "\n",
    "Design Strategy:\n",
    "\n",
    "1) Minimal efficient python implementatations. \n",
    "2) Batched for GPU.\n",
    "3) Code can be ported to other backends\n",
    "\n",
    "\n",
    "## Low-level API: \n",
    "\n",
    "Everything implemented through semiring dynamic programming. \n",
    "\n",
    "* Log Marginals\n",
    "* Max and MAP computation\n",
    "* Sampling through specialized backprop\n",
    "* Entropy and first-order semirings. \n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Examples\n",
    "\n",
    "* BERT <a href=\"https://github.com/harvardnlp/pytorch-struct/blob/master/notebooks/BertTagger.ipynb\">Part-of-Speech</a> \n",
    "* BERT <a href=\"https://github.com/harvardnlp/pytorch-struct/blob/master/notebooks/BertDependencies.ipynb\">Dependency Parsing</a>\n",
    "* Unsupervised Learning (to come)\n",
    "* Structured VAE (to come)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
